{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7XNMxdTwURqI"
   },
   "source": [
    "# External Callbacks in JAX\n",
    "\n",
    "<!--* freshness: { reviewed: '2024-04-08' } *-->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h6lXo6bSUYGq"
   },
   "source": [
    "This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xi_nhfpnlmbm"
   },
   "source": [
    "## Why callbacks?\n",
    "\n",
    "A callback routine is a way to perform **host-side** execution of code at runtime.\n",
    "As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n",
    "Using a plain Python `print` statement, it looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "lz8rEL1Amb4r",
    "outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "\n",
    "@jax.jit\n",
    "def f(x):\n",
    "  y = x + 1\n",
    "  print(\"intermediate value: {}\".format(y))\n",
    "  return y * 2\n",
    "\n",
    "result = f(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yEy41sFAmxOp"
   },
   "source": [
    "What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n",
    "\n",
    "To print the value at runtime, we need a callback such as `jax.debug.print`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "wFfHmoQxnKDF",
    "outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "intermediate value: 3\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def f(x):\n",
    "  y = x + 1\n",
    "  jax.debug.print(\"intermediate value: {}\", y)\n",
    "  return y * 2\n",
    "\n",
    "result = f(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CvWv3pudn9X5"
   },
   "source": [
    "This works by passing the runtime value represented by `y` back to the host process, where the host can print the value."
   ]
  },
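  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`jax.debug.print` formats its message like Python's `str.format`, so placeholders can also refer to keyword arguments by name. Here is a minimal sketch (the function name `g` is purely illustrative) that prints two runtime values from inside a jitted function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def g(x):\n",
    "  y = jnp.sin(x)\n",
    "  # The placeholders are filled in with runtime values on the host, not tracers.\n",
    "  jax.debug.print('x = {x}, sin(x) = {y}', x=x, y=y)\n",
    "  return y\n",
    "\n",
    "out = g(1.0)"
   ]
  },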
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X0vR078znuT-"
   },
   "source": [
    "## Flavors of Callback\n",
    "\n",
    "In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n",
    "\n",
    "- {func}`jax.pure_callback`: appropriate for pure functions, i.e. functions with no side effects.\n",
    "- {func}`jax.experimental.io_callback`: appropriate for impure functions, e.g. functions that read or write data to disk.\n",
    "- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n",
    "\n",
    "(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`.)\n",
    "\n",
    "From the user's perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n",
    "\n",
    "|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n",
    "|-------------------------------------|----|----|----|----|----|----|\n",
    "|`jax.pure_callback`            | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n",
    "|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n",
    "|`jax.debug.callback`           | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n",
    "\n",
    "¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff.\n",
    "\n",
    "² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n",
    "\n",
    "³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases."
   ]
  },
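  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the table, the following sketch routes one host-side NumPy function through all three callbacks inside a single jitted function (the names `host_square` and `compare_callbacks` are illustrative). `jax.pure_callback` and `io_callback` both hand a value back to the computation, while `jax.debug.callback` is called only for its side effect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from jax.experimental import io_callback\n",
    "\n",
    "def host_square(v):\n",
    "  # Executes on the host and receives a NumPy array.\n",
    "  return np.square(v)\n",
    "\n",
    "@jax.jit\n",
    "def compare_callbacks(v):\n",
    "  spec = jax.ShapeDtypeStruct(v.shape, v.dtype)\n",
    "  a = jax.pure_callback(host_square, spec, v)  # assumed pure: may be elided or duplicated\n",
    "  b = io_callback(host_square, spec, v)  # treated as impure: always executed\n",
    "  jax.debug.callback(lambda t: print('debug.callback saw', t), v)  # returns no value\n",
    "  return a, b\n",
    "\n",
    "a, b = compare_callbacks(jnp.arange(3.0))"
   ]
  },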
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hE_M8DaPvoym"
   },
   "source": [
    "### Exploring `jax.pure_callback`\n",
    "\n",
    "`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n",
    "\n",
    "The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "4lQDzXy6t_-k",
    "outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "def f_host(x):\n",
    "  # call a numpy (not jax.numpy) operation:\n",
    "  return np.sin(x).astype(x.dtype)\n",
    "\n",
    "def f(x):\n",
    "  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
    "  return jax.pure_callback(f_host, result_shape, x)\n",
    "\n",
    "x = jnp.arange(5.0)\n",
    "f(x)"
   ]
  },
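  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result specification passed to `jax.pure_callback` (`result_shape` above) may also be a pytree of `jax.ShapeDtypeStruct` objects. This is useful when the host function returns several arrays. Here is a minimal sketch using the illustrative names `host_sincos` and `sincos`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "def host_sincos(v):\n",
    "  # Return a tuple whose structure matches the (spec, spec) specification below.\n",
    "  return np.sin(v).astype(v.dtype), np.cos(v).astype(v.dtype)\n",
    "\n",
    "@jax.jit\n",
    "def sincos(v):\n",
    "  spec = jax.ShapeDtypeStruct(v.shape, v.dtype)\n",
    "  return jax.pure_callback(host_sincos, (spec, spec), v)\n",
    "\n",
    "s, c = sincos(jnp.arange(3.0))"
   ]
  },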
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q7YCIr8qMrDs"
   },
   "source": [
    "Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "bgoZ0fxsuoWV",
    "outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.jit(f)(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "ajBRGWGfupu2",
    "outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.vmap(f)(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "xe7AOGexvC13",
    "outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def body_fun(_, x):\n",
    "  return _, f(x)\n",
    "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tMzAVs2VNj5G"
   },
   "source": [
    "However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "4QAF4VhUu5bb",
    "outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exception reporting mode: Minimal\n"
     ]
    }
   ],
   "source": [
    "%xmode minimal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "qUpKPxlOurfY",
    "outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e",
    "tags": [
     "raises-exception"
    ]
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "ignored",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n"
     ]
    }
   ],
   "source": [
    "jax.grad(f)(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y9DAibV4Nwpo"
   },
   "source": [
    "For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LrvdAloMZbIe"
   },
   "source": [
    "By design, functions passed to `pure_callback` are treated as if they have no side-effects. One consequence is that if the output of the function is not used, the compiler may eliminate the callback entirely:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "mmFc_zawZrBq",
    "outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "printing something\n"
     ]
    }
   ],
   "source": [
    "def print_something():\n",
    "  print('printing something')\n",
    "  return np.int32(0)\n",
    "\n",
    "@jax.jit\n",
    "def f1():\n",
    "  return jax.pure_callback(print_something, np.int32(0))\n",
    "f1();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "tTwE4kpmaNei"
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def f2():\n",
    "  jax.pure_callback(print_something, np.int32(0))\n",
    "  return 1.0\n",
    "f2();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qfyGYbw4Z5U3"
   },
   "source": [
    "In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n",
    "In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JHcJybr7OEBM"
   },
   "source": [
    "### Exploring `jax.experimental.io_callback`\n",
    "\n",
    "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n",
    "\n",
    "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because generating a random number in numpy has a side-effect: the random state is updated. (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "eAg5xIhrOiWV",
    "outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating float32[5]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from jax.experimental import io_callback\n",
    "\n",
    "global_rng = np.random.default_rng(0)\n",
    "\n",
    "def host_side_random_like(x):\n",
    "  \"\"\"Generate a random array like x using the global_rng state\"\"\"\n",
    "  # We have two side-effects here:\n",
    "  # - printing the shape and dtype\n",
    "  # - calling global_rng, thus updating its state\n",
    "  print(f'generating {x.dtype}{list(x.shape)}')\n",
    "  return global_rng.uniform(size=x.shape).astype(x.dtype)\n",
    "\n",
    "@jax.jit\n",
    "def numpy_random_like(x):\n",
    "  return io_callback(host_side_random_like, x, x)\n",
    "\n",
    "x = jnp.zeros(5)\n",
    "numpy_random_like(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mAIF31MlXj33"
   },
   "source": [
    "The `io_callback` is compatible with `vmap` by default:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "NY3o5dG6Vg6u",
    "outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.vmap(numpy_random_like)(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XXvSeeOXXquZ"
   },
   "source": [
    "Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n",
    "\n",
    "If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "3aNmRsDrX3-2",
    "outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274",
    "tags": [
     "raises-exception"
    ]
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "ignored",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def numpy_random_like_ordered(x):\n",
    "  return io_callback(host_side_random_like, x, x, ordered=True)\n",
    "\n",
    "jax.vmap(numpy_random_like_ordered)(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fD2FTHlUYAZH"
   },
   "source": [
    "On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "lMVzZlIEWL7F",
    "outputId": "f9741c18-a30d-4d46-b706-8102849286b5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n",
      "generating float32[]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def body_fun(_, x):\n",
    "  return _, numpy_random_like_ordered(x)\n",
    "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
   ]
  },
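  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`jax.lax.while_loop` behaves the same way. As a minimal sketch (the generator `loop_rng` and the functions `host_uniform` and `draw_until_large` are illustrative), the loop below draws host-side random numbers through an ordered `io_callback` until one reaches a threshold. It returns the number of draws along with the last value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from jax.experimental import io_callback\n",
    "\n",
    "loop_rng = np.random.default_rng(1)\n",
    "\n",
    "def host_uniform():\n",
    "  # Impure: advances the state of loop_rng on every call.\n",
    "  return np.asarray(loop_rng.uniform(), dtype=np.float32)\n",
    "\n",
    "@jax.jit\n",
    "def draw_until_large(threshold):\n",
    "  def cond(state):\n",
    "    _, value = state\n",
    "    return value < threshold\n",
    "\n",
    "  def body(state):\n",
    "    count, _ = state\n",
    "    value = io_callback(host_uniform, jax.ShapeDtypeStruct((), jnp.float32),\n",
    "                        ordered=True)\n",
    "    return count + 1, value\n",
    "\n",
    "  # Carry: (number of draws so far, most recent draw).\n",
    "  return jax.lax.while_loop(cond, body, (0, jnp.float32(0.0)))\n",
    "\n",
    "count, value = draw_until_large(0.5)"
   ]
  },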
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w_sf8mCbbo8K"
   },
   "source": [
    "Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "Cn6_RG4JcKZm",
    "outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07",
    "tags": [
     "raises-exception"
    ]
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "ignored",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n"
     ]
    }
   ],
   "source": [
    "jax.grad(numpy_random_like)(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "plvfn9lWcKu4"
   },
   "source": [
    "However, if the callback is not dependent on a differentiated variable, it will execute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "wxgfDmDfb5bx",
    "outputId": "d8c0285c-cd04-4b4d-d15a-1b07f778882d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hello\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def f(x):\n",
    "  io_callback(lambda: print('hello'), None)\n",
    "  return x\n",
    "\n",
    "jax.grad(f)(1.0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "STLI40EZcVIY"
   },
   "source": [
    "Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation."
   ]
  },
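  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the compiler keeps `io_callback` even when its output is unused, you can call it purely for a side effect by passing `None` as the result specification. The following minimal sketch assumes a host-side Python list named `host_log` as the destination (all names here are illustrative). `jax.effects_barrier()` blocks until pending callbacks have run, so the list can be inspected safely afterwards:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax.experimental import io_callback\n",
    "\n",
    "host_log = []\n",
    "\n",
    "def append_to_log(v):\n",
    "  # Side effect only: nothing is returned to the JAX program.\n",
    "  host_log.append(float(v))\n",
    "\n",
    "@jax.jit\n",
    "def logged_double(x):\n",
    "  io_callback(append_to_log, None, x, ordered=True)\n",
    "  return 2 * x\n",
    "\n",
    "logged_double(1.0)\n",
    "logged_double(2.0)\n",
    "jax.effects_barrier()  # wait for the pending callbacks to finish\n",
    "print(host_log)"
   ]
  },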
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pkkM1ZmqclV-"
   },
   "source": [
    "### Exploring `jax.debug.callback`\n",
    "\n",
    "Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "74TdWyu9eqBa",
    "outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log: 1.0\n"
     ]
    }
   ],
   "source": [
    "from jax import debug\n",
    "\n",
    "def log_value(x):\n",
    "  # This could be an actual logging call; we'll use\n",
    "  # print() for demonstration\n",
    "  print(\"log:\", x)\n",
    "\n",
    "@jax.jit\n",
    "def f(x):\n",
    "  debug.callback(log_value, x)\n",
    "  return x\n",
    "\n",
    "f(1.0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P848STlsfzmW"
   },
   "source": [
    "The debug callback is compatible with `vmap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "2sSNsPB-fGVI",
    "outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log: 0.0\n",
      "log: 1.0\n",
      "log: 2.0\n",
      "log: 3.0\n",
      "log: 4.0\n"
     ]
    }
   ],
   "source": [
    "x = jnp.arange(5.0)\n",
    "jax.vmap(f)(x);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VDMacqpXf3La"
   },
   "source": [
    "It is also compatible with `grad` and other autodiff transformations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "wkFRle-tfTDe",
    "outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log: 1.0\n"
     ]
    }
   ],
   "source": [
    "jax.grad(f)(1.0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w8t-SDZ3gRzE"
   },
   "source": [
    "This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`."
   ]
  },
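  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, `jax.debug.callback` can report values from every iteration of a compiled loop. The following minimal sketch uses the illustrative helpers `log_step` and `doubling_loop`. Passing `ordered=True` requests that the callbacks run in program order:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "def log_step(i, v):\n",
    "  # Receives the loop index and carry as host-side NumPy values.\n",
    "  print(f'step {i}: {v}')\n",
    "\n",
    "@jax.jit\n",
    "def doubling_loop(x):\n",
    "  def body(i, v):\n",
    "    v = v * 2\n",
    "    jax.debug.callback(log_step, i, v, ordered=True)\n",
    "    return v\n",
    "  return jax.lax.fori_loop(0, 3, body, x)\n",
    "\n",
    "result_loop = doubling_loop(1.0)"
   ]
  },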
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dF7hoWGQUneJ"
   },
   "source": [
    "## Example: `pure_callback` with `custom_jvp`\n",
    "\n",
    "One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n",
    "Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n",
    "\n",
    "Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n",
    "We can start by defining a straightforward `pure_callback`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ge4fNPZdVSJY"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import scipy.special\n",
    "\n",
    "def jv(v, z):\n",
    "  v, z = jnp.asarray(v), jnp.asarray(z)\n",
    "\n",
    "  # Require the order v to be integer type: this simplifies\n",
    "  # the JVP rule below.\n",
    "  assert jnp.issubdtype(v.dtype, jnp.integer)\n",
    "\n",
    "  # Promote the input to inexact (float/complex).\n",
    "  # Note that jnp.result_type() accounts for the enable_x64 flag.\n",
    "  z = z.astype(jnp.result_type(float, z.dtype))\n",
    "\n",
    "  # Wrap scipy function to return the expected dtype.\n",
    "  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n",
    "\n",
    "  # Define the expected shape & dtype of output.\n",
    "  result_shape_dtype = jax.ShapeDtypeStruct(\n",
    "      shape=jnp.broadcast_shapes(v.shape, z.shape),\n",
    "      dtype=z.dtype)\n",
    "\n",
    "  # We use vectorized=True because scipy.special.jv handles broadcasted inputs.\n",
    "  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vyjQj-0QVuoN"
   },
   "source": [
    "This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f4e46670f4e4"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "j1 = partial(jv, 1)\n",
    "z = jnp.arange(5.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6svImqFHWBwj",
    "outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]\n"
     ]
    }
   ],
   "source": [
    "print(j1(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d48eb4f2d48e"
   },
   "source": [
    "Here is the same result with `jit`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "txvRqR9DWGdC",
    "outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]\n"
     ]
    }
   ],
   "source": [
    "print(jax.jit(j1)(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d861a472d861"
   },
   "source": [
    "And here is the same result again with `vmap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BS-Ve5u_WU0C",
    "outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]\n"
     ]
    }
   ],
   "source": [
    "print(jax.vmap(j1)(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SCH2ii_dWXP6"
   },
   "source": [
    "However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "q3qh_4DrWxdQ",
    "outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870",
    "tags": [
     "raises-exception"
    ]
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "ignored",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mUnfilteredStackTrace\u001b[0m                      Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-fde6421013cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    161\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    163\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1090\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m     \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1092\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    161\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    163\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1166\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m       \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1168\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m   2655\u001b[0m     \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m     out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m   2657\u001b[0m         flat_fun, primals_flat, reduce_axes=reduce_axes)\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m    134\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m     \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    136\u001b[0m   \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m    123\u001b[0m   \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m   \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    125\u001b[0m   \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    313\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    315\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m    766\u001b[0m     \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m     \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    768\u001b[0m     \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    166\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m       \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m     \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-1-9b5b54cddb29>\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m     24\u001b[0m   \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m   3425\u001b[0m   \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m    130\u001b[0m   \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m   out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m    132\u001b[0m       \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m    328\u001b[0m             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m    331\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m     \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    333\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m    309\u001b[0m       \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m     \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    311\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m     54\u001b[0m   \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m   raise ValueError(\n\u001b[0m\u001b[1;32m     56\u001b[0m       \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-fde6421013cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-1-9b5b54cddb29>\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m   \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m    129\u001b[0m       lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m    130\u001b[0m   \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m   out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m    132\u001b[0m       \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m       result_avals=tuple(flat_result_avals), vectorized=vectorized)\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m     53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m   \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m   raise ValueError(\n\u001b[0m\u001b[1;32m     56\u001b[0m       \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m       \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n",
      "\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients."
     ]
    }
   ],
   "source": [
    "jax.grad(j1)(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PtYeJ_xUW09v"
   },
   "source": [
    "Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find a relatively simple recurrence relation for the derivative with respect to the argument `z`:\n",
    "\n",
    "$$\n",
    "\\frac{d J_\\nu(z)}{dz} = \\begin{cases}\n",
    "-J_1(z), & \\nu = 0\\\\\n",
    "[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2, & \\nu \\ne 0\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n",
    "\n",
    "We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:"
   ]
  },
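  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick numerical sanity check of this recurrence (an illustrative addition, not something JAX itself requires), the following sketch compares it with a central finite difference of `scipy.special.jv`, evaluated in float64 on the host. The grid `z_grid`, the step size `h`, and the orders tested are arbitrary choices for this example; the reported errors should be tiny compared with the values of $J_\\nu(z)$, which for integer orders are at most 1 in magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.special\n",
    "\n",
    "# Illustrative check (hypothetical names z_grid, h, max_errors): compare the\n",
    "# recurrence for dJ_v/dz with a central finite difference of scipy.special.jv.\n",
    "z_grid = np.linspace(0.5, 10.0, 50)\n",
    "h = 1e-5  # finite-difference step size\n",
    "\n",
    "max_errors = {}\n",
    "for v in range(4):\n",
    "  finite_diff = (scipy.special.jv(v, z_grid + h) - scipy.special.jv(v, z_grid - h)) / (2 * h)\n",
    "  if v == 0:\n",
    "    recurrence = -scipy.special.jv(1, z_grid)\n",
    "  else:\n",
    "    recurrence = 0.5 * (scipy.special.jv(v - 1, z_grid) - scipy.special.jv(v + 1, z_grid))\n",
    "  max_errors[v] = float(np.max(np.abs(finite_diff - recurrence)))\n",
    "  print(f'v={v}: max |finite difference - recurrence| = {max_errors[v]:.2e}')"
   ]
  },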
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BOVQnt05XvLs"
   },
   "outputs": [],
   "source": [
    "jv = jax.custom_jvp(jv)\n",
    "\n",
    "@jv.defjvp\n",
    "def _jv_jvp(primals, tangents):\n",
    "  v, z = primals\n",
    "  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.\n",
    "  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n",
    "  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n",
    "  return jv(v, z), z_dot * djv_dz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W1SxcvQSX44c"
   },
   "source": [
    "Now computing the gradient of our function will work correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sCGceBs-X8nL",
    "outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.06447162\n"
     ]
    }
   ],
   "source": [
    "j1 = partial(jv, 1)\n",
    "print(jax.grad(j1)(2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gWQ4phN5YB26"
   },
   "source": [
    "Further, since we've defined the JVP rule in terms of `jv` itself, JAX can differentiate that rule too, so second-order and higher derivatives come for free:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QTe5mRAvYQBh",
    "outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(-0.4003078, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.hessian(j1)(2.0)"
   ]
  },
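  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an independent cross-check of these higher-order derivatives, SciPy's `scipy.special.jvp(v, z, n)` evaluates the $n$-th derivative of $J_v(z)$ directly. The following is a minimal sketch that re-creates a compact, standalone version of the callback-based `jv` (the names `bessel_jv`, `_bessel_jv_jvp`, and `j1_standalone` are hypothetical, and `z` is assumed to be float32) so that it runs on its own. It uses `jax.grad(jax.grad(...))` rather than `jax.hessian`, since `jax.hessian` is built on `vmap`, and how `jax.pure_callback` behaves under `vmap` depends on the options you pass and on your JAX version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import scipy.special\n",
    "\n",
    "# Hypothetical standalone re-creation of the callback-based `jv` defined\n",
    "# earlier, simplified to float32 `z` and integer orders `v`.\n",
    "@jax.custom_jvp\n",
    "def bessel_jv(v, z):\n",
    "  z = jnp.asarray(z, dtype=jnp.float32)\n",
    "  out_type = jax.ShapeDtypeStruct(jnp.broadcast_shapes(jnp.shape(v), z.shape), z.dtype)\n",
    "  # Host-side callback: evaluate with SciPy, then cast back to float32.\n",
    "  def _scipy_jv(v, z):\n",
    "    return np.asarray(scipy.special.jv(v, z), dtype=np.float32)\n",
    "  return jax.pure_callback(_scipy_jv, out_type, v, z)\n",
    "\n",
    "@bessel_jv.defjvp\n",
    "def _bessel_jv_jvp(primals, tangents):\n",
    "  v, z = primals\n",
    "  _, z_dot = tangents  # v is an integer, so it has no meaningful tangent.\n",
    "  jv_minus_1, jv_plus_1 = bessel_jv(v - 1, z), bessel_jv(v + 1, z)\n",
    "  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n",
    "  return bessel_jv(v, z), z_dot * djv_dz\n",
    "\n",
    "j1_standalone = partial(bessel_jv, 1)\n",
    "# Reverse-over-reverse second derivative; no vmap is involved.\n",
    "second_derivative = jax.grad(jax.grad(j1_standalone))(2.0)\n",
    "# SciPy's direct evaluation of the second derivative of J_1 at z = 2.\n",
    "reference = scipy.special.jvp(1, 2.0, n=2)\n",
    "print(second_derivative, reference)"
   ]
  },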
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QEXGxU4uYZii"
   },
   "source": [
    "Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n",
    "When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n",
    "However, if you are running JAX on a single CPU (where the \"host\" and \"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
